import argparse
from run_setups import train_run, train_no_Stack_run, test_run, johnson_run, test_run_visualize


def run(config_dict):
    """Dispatch to the experiment routine named by ``config_dict["experiment_type"]``.

    Args:
        config_dict: Mutable mapping of run settings. It must contain the key
            ``"experiment_type"``. For the ``"test"`` and ``"test_visualize"``
            modes, ``config_dict["test_type"]`` is set to ``'standard'`` in
            place before the runner is called.

    Raises:
        KeyError: If ``"experiment_type"`` is missing from ``config_dict``.
        ValueError: If the experiment type is not one of the supported values.
    """
    # Read the selector once; the modes are mutually exclusive.
    experiment_type = config_dict["experiment_type"]

    if experiment_type in ("train", "train_with_cost_perturbation"):
        # Both training variants share the same runner.
        train_run(config_dict)
    elif experiment_type == "johnson":
        johnson_run(config_dict)
    elif experiment_type == "train_no_Stackelberg":
        train_no_Stack_run(config_dict)
    elif experiment_type == "test":
        config_dict["test_type"] = 'standard'
        test_run(config_dict)
    elif experiment_type == "test_visualize":
        config_dict["test_type"] = 'standard'
        test_run_visualize(config_dict)
    else:
        # Fail loudly instead of silently doing nothing on a misspelled mode.
        raise ValueError(
            f"Unknown experiment_type {experiment_type!r}; expected one of: "
            "train, train_with_cost_perturbation, johnson, "
            "train_no_Stackelberg, test, test_visualize"
        )


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the experiment launcher.

    Returns:
        An ``argparse.ArgumentParser`` whose parsed namespace, converted with
        ``vars()``, is the configuration dict handed to ``run``.
    """
    parser = argparse.ArgumentParser(description="AI collusion")

    parser.add_argument(
        "--seed",
        type=int,
        default=12345,
        help="The random number generator seed; default: 12345",
    )

    # Float options use the builtin ``float`` directly so that argparse
    # reports "invalid float value" rather than "invalid <lambda> value".
    parser.add_argument(
        "--gamma",
        type=float,
        default=1.0,
        help="Level of buybox intervention: 0 no buybox, 1 buybox; default: 1",
    )

    parser.add_argument(
        "--dp_type",
        type=str,
        default="learn_threshold",
        help="Buybox type; default: learn_threshold; options: learn_threshold, pdp, dpdp, no_intervene",
    )

    parser.add_argument(
        "--max_steps",
        type=int,
        default=50000000,
        help="Number of training steps until run stops",
    )

    parser.add_argument(
        "--algorithm",
        type=str,
        default="A2C",
        help="DeepRL algorithm used to train platform policy",
    )

    parser.add_argument(
        "--tot_num_reward_steps",
        type=int,
        default=30,
        help="Number of reward steps in Stackelberg POMDP",
    )

    parser.add_argument(
        "--tot_num_eq_steps",
        type=int,
        default=50000,
        help="Number of equilibrium steps in Stackelberg POMDP",
    )

    parser.add_argument(
        "--bbox_state_space_type",
        type=str,
        default="price_profile",
        help="Defines type of buybox. Options are price_profile, no_state",
    )

    parser.add_argument(
        "--frac_excluded_eq_steps",
        type=float,
        default=0.0,
        help="Sets the fraction of equilibrium steps that are not exposed to the learner",
    )

    parser.add_argument(
        "--reward_step_random_price_prob",
        type=float,
        default=0.0,
        help="Sets the probability that an evaluation step happens with random prices",
    )

    parser.add_argument(
        "--price_grid_length",
        type=int,
        default=5,
        help="Sets the number of discrete prices we consider",
    )

    # Restrict to the modes that ``run`` dispatches on, so a typo is rejected
    # at parse time instead of producing a run that does nothing.
    parser.add_argument(
        "--experiment_type",
        type=str,
        default="train",
        choices=(
            "train",
            "test",
            "test_visualize",
            "johnson",
            "train_no_Stackelberg",
            "train_with_cost_perturbation",
        ),
        help="Type of experiment to run. Options: train, test, test_visualize, johnson, train_no_Stackelberg, train_with_cost_perturbation",
    )

    parser.add_argument(
        "--q_restart_rate",
        type=float,
        default=-1,
        help="In expectation, each agent restarts its exploration rate q_restart_rate times per episode; if -1, we restart at the beginning of each episode",
    )

    parser.add_argument(
        "--marginal_cost",
        type=float,
        default=1.0,
        help="The marginal cost for each agent",
    )

    parser.add_argument(
        "--grid_upper_bound",
        type=float,
        default=2.1,
        help="Price grid upper bound; default: 2.1 as in Johnson et al.",
    )

    parser.add_argument(
        "--critic_obs",
        type=str,
        default="none",
        help="Determines which additional info is given to critic network. Options: flag, none, Q_matrix",
    )

    return parser


if __name__ == "__main__":
    # vars() turns the parsed Namespace into the plain dict that run() expects.
    config_dict = vars(build_parser().parse_args())

    run(config_dict)
